;;;; foreign.test --- FFI.           -*- mode: scheme; coding: utf-8; -*-
;;;;
;;;; 	Copyright (C) 2010, 2011, 2012, 2013, 2017, 2021 Free Software Foundation, Inc.
;;;;
;;;; This library is free software; you can redistribute it and/or
;;;; modify it under the terms of the GNU Lesser General Public
;;;; License as published by the Free Software Foundation; either
;;;; version 3 of the License, or (at your option) any later version.
;;;;
;;;; This library is distributed in the hope that it will be useful,
;;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;;; Lesser General Public License for more details.
;;;;
;;;; You should have received a copy of the GNU Lesser General Public
;;;; License along with this library; if not, write to the Free Software
;;;; Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

;;;
;;; See also ../standalone/test-ffi for FFI tests.
;;;

(define-module (test-foreign)
  #:use-module (system foreign-library)
  #:use-module (system foreign)
  #:use-module (rnrs bytevectors)
  #:use-module (srfi srfi-1)
  #:use-module (srfi srfi-26)
  #:use-module (ice-9 format)
  #:use-module (test-suite lib))


(with-test-prefix "foreign-library-pointer"

  ;; Resolving a symbol that does not exist must signal `misc-error'.
  ;; The message itself comes from dlsym and is system-dependent, hence
  ;; the empty message pattern.
  (pass-if-exception "error message"
    '(misc-error . "")
    (let ((bogus-symbol "does_not_exist___"))
      (foreign-library-pointer #f bogus-symbol))))


(with-test-prefix "null pointer"

  ;; `%null-pointer' is a pointer object...
  (pass-if "pointer?"
    (pointer? %null-pointer))

  ;; ...whose address is zero.
  (pass-if "zero"
    (zero? (pointer-address %null-pointer)))

  ;; A pointer made from address 0 is expected to be the very same
  ;; object as `%null-pointer', not merely an `equal?' one.
  (pass-if "null pointer identity"
    (let ((from-zero (make-pointer 0)))
      (eq? %null-pointer from-zero)))

  (pass-if "null-pointer? %null-pointer"
    (null-pointer? %null-pointer))

  ;; Reading through the null pointer, either directly or via a
  ;; bytevector view, must raise a null-pointer error.
  (pass-if-exception "dereference-pointer %null-pointer"
    exception:null-pointer-error
    (dereference-pointer %null-pointer))

  (pass-if-exception "pointer->bytevector %null-pointer"
    exception:null-pointer-error
    (let ((len 7))
      (pointer->bytevector %null-pointer len))))


(with-test-prefix "make-pointer"

  (pass-if "pointer?"
    (let ((p (make-pointer 123)))
      (pointer? p)))

  (pass-if "address preserved"
    (let ((address 123))
      (= address (pointer-address (make-pointer address)))))

  ;; Two pointer objects built from the same address compare `equal?'.
  (pass-if "equal?"
    (let ((a (make-pointer 123))
          (b (make-pointer 123)))
      (equal? a b)))

  ;; A pointer with a finalizer must still be `equal?' to one without.
  ;; The address of a C function is used as the finalizer; when it
  ;; cannot be looked up (Windows or a static build) the test is
  ;; reported as unresolved.
  (pass-if "equal? modulo finalizer"
    (let ((fin (false-if-exception
                (foreign-library-pointer #f "scm_is_pair"))))
      (cond
       (fin
        (equal? (make-pointer 123)
                (make-pointer 123 fin)))
       (else
        (throw 'unresolved)))))

  ;; Same check, with the finalizer attached after the pointer was made.
  (pass-if "equal? modulo finalizer (set-pointer-finalizer!)"
    (let ((fin   (false-if-exception
                  (foreign-library-pointer #f "scm_is_pair")))
          (plain (make-pointer 123)))
      (cond
       (fin
        (set-pointer-finalizer! plain fin)
        (equal? (make-pointer 123) plain))
       (else
        (throw 'unresolved)))))

  (pass-if "not equal?"
    (let ((a (make-pointer 123))
          (b (make-pointer 456)))
      (not (equal? a b)))))


(with-test-prefix "pointer<->scm"

  ;; Converting an object to a pointer and back must give an `equal?'
  ;; object, for an immediate (a character) and a non-immediate (a
  ;; string) alike.

  (pass-if "immediates"
    (let ((obj #\newline))
      (equal? (pointer->scm (scm->pointer obj))
              obj)))

  (pass-if "non-immediates"
    (let ((obj "Hello, world!"))
      (equal? (pointer->scm (scm->pointer obj))
              obj))))


;; Wrapped pointer type used by the tests below: predicate `foo?',
;; constructor `wrap-foo', accessor `unwrap-foo', and a printer that
;; shows the address of the wrapped pointer.
(define-wrapped-pointer-type foo
  foo?
  wrap-foo unwrap-foo
  (lambda (obj port)
    (let ((address (pointer-address (unwrap-foo obj))))
      (format port "#<foo! ~a>" address))))

(with-test-prefix "define-wrapped-pointer-type"

  (pass-if "foo?"
    (let ((wrapped (wrap-foo %null-pointer)))
      (foo? wrapped)))

  ;; Unwrapping gives back the very pointer object that was wrapped.
  (pass-if "unwrap-foo"
    (let* ((raw     (make-pointer 123))
           (wrapped (wrap-foo raw)))
      (eq? raw (unwrap-foo wrapped))))

  ;; Wrapping two distinct pointer objects that hold the same address is
  ;; expected to yield one and the same (`eq?') wrapper.
  (pass-if "identity"
    (let* ((ptr-a  (make-pointer 123))
           (ptr-b  (make-pointer 123))
           (wrap-a (wrap-foo ptr-a))
           (wrap-b (wrap-foo ptr-b)))
      (eq? wrap-a wrap-b)))

  ;; `write' must go through the printer given in the type definition.
  (pass-if "printer"
    (let* ((wrapped (wrap-foo (make-pointer 123)))
           (shown   (call-with-output-string
                      (lambda (port)
                        (write wrapped port)))))
      (string=? "#<foo! 123>" shown))))


(with-test-prefix "pointer<->bytevector"

  (define (bytes->native-uint bytes)
    ;; Combine BYTES, a list of octets in memory order, into the unsigned
    ;; integer they denote in the native byte order.
    (let ((combine (case (native-endianness)
                     ((little) fold-right)
                     ((big)    fold)
                     (else     (error "unsupported endianness")))))
      (combine (lambda (byte acc)
                 (+ byte (* 256 acc)))
               0
               bytes)))

  ;; Viewing a bytevector's memory through a pointer must give back
  ;; `equal?' contents.
  (pass-if "bijection"
    (let* ((bv  #vu8(0 1 2 3 4 5 6 7))
           (len (bytevector-length bv))
           (ptr (bytevector->pointer bv)))
      (equal? (pointer->bytevector ptr len) bv)))

  ;; A pointer built from a pointer-sized integer read out of a
  ;; bytevector must carry exactly that integer as its address.
  (pass-if "pointer from bits"
    (let* ((width    (sizeof '*))
           (bytes    (iota width))
           (bv       (u8-list->bytevector bytes))
           (expected (bytes->native-uint bytes))
           (bits     (bytevector-uint-ref bv 0 (native-endianness) width)))
      (= (pointer-address (make-pointer bits))
         expected)))

  ;; Dereferencing a pointer to the same bytes must yield a pointer with
  ;; that same address.
  (pass-if "dereference-pointer"
    (let* ((bytes    (iota (sizeof '*)))
           (bv       (u8-list->bytevector bytes))
           (expected (bytes->native-uint bytes))
           (target   (dereference-pointer (bytevector->pointer bv))))
      (= (pointer-address target) expected))))


(with-test-prefix "pointer<->string"

  ;; The first three tests encode a string containing characters that
  ;; cannot be represented in ISO-8859-1, once for each value of
  ;; `%default-port-conversion-strategy'.

  (pass-if-exception "%default-port-conversion-strategy is error"
    exception:encoding-error
    (let ((s "χαοσ"))
      (with-fluids ((%default-port-conversion-strategy 'error))
        (string->pointer s "ISO-8859-1"))))

  (pass-if "%default-port-conversion-strategy is escape"
    (let ((s "teĥniko"))
      ;; The offending character is expected to come back as a `\uXXXX'
      ;; escape built from its code point.
      (equal? (with-fluids ((%default-port-conversion-strategy 'escape))
                (pointer->string (string->pointer s "ISO-8859-1")))
              (format #f "te\\u~4,'0xniko"
                      (char->integer #\ĥ)))))

  (pass-if "%default-port-conversion-strategy is substitute"
    (let ((s "teĥniko")
          ;; Double negation turns the list returned by `member' into a
          ;; plain boolean.
          (member (negate (negate member))))
      (member (with-fluids ((%default-port-conversion-strategy 'substitute))
                (pointer->string (string->pointer s "ISO-8859-1")))
              '("te?niko"

                ;; This form is found on FreeBSD 8.2 and Darwin 10.8.0.
                "te^hniko"))))

  ;; The remaining tests check that a string survives a round trip
  ;; through `string->pointer' and `pointer->string'.

  (pass-if "bijection"
    (let ((s "hello, world"))
      (string=? s (pointer->string (string->pointer s)))))

  (pass-if "bijection [latin1]"
    (with-latin1-locale
      (let ((s "Szép jó napot!"))
        (string=? s (pointer->string (string->pointer s))))))

  (pass-if "bijection, utf-8"
    (let ((s "hello, world"))
      (string=? s (pointer->string (string->pointer s "utf-8")
                                   -1 "utf-8"))))

  ;; Run under the Latin-1 locale, as the test name says and as the
  ;; "bijection [latin1]" test above does: the explicit "utf-8" encoding
  ;; arguments must take effect even when the locale encoding differs.
  (pass-if "bijection, utf-8 [latin1]"
    (with-latin1-locale
      (let ((s "Szép jó napot!"))
        (string=? s (pointer->string (string->pointer s "utf-8")
                                     -1 "utf-8"))))))



(with-test-prefix "pointer->procedure"

  ;; The foreign procedure is declared to take one pointer argument, so
  ;; calling it with a non-pointer object must raise a wrong-type error.
  (pass-if-exception "object instead of pointer"
    exception:wrong-type-arg
    (let* ((arg-types (list '*))
           (foreign   (pointer->procedure '* %null-pointer arg-types)))
      (foreign #t))))


(with-test-prefix "procedure->pointer"

  (define qsort
    ;; Bindings for libc's `qsort' function.
    ;; On some platforms, such as MinGW, `qsort' is visible only if
    ;; linking with `-export-dynamic'.  Just skip these tests when it's
    ;; not visible.
    (false-if-exception
     (foreign-library-function
      (cond
       ((string-contains %host-type "cygwin")
        ;; On Cygwin, load-foreign-library does not search recursively
        ;; into linked DLLs. Thus, one needs to link to the core C
        ;; library DLL explicitly.
        "cygwin1")
       (else #f))
      "qsort" #:arg-types (list '* size_t size_t '*))))

  (define (dereference-pointer-to-byte ptr)
    ;; Return the unsigned byte stored at the address PTR points to.
    (let ((b (pointer->bytevector ptr 1)))
      (bytevector-u8-ref b 0)))

  (define input
    ;; Unsorted byte values handed to `qsort' by the tests below.
    '(7 1 127 3 5 4 77 2 9 0))

  ;; Sort INPUT in place with libc's `qsort', using a Scheme procedure as
  ;; the comparison callback.  The callback pointer is created once and
  ;; that same pointer is the one passed to `qsort'.
  (pass-if "qsort"
    (if (and qsort (defined? 'procedure->pointer))
        (let* ((called? #f)
               (cmp     (lambda (x y)
                          (set! called? #t)
                          (- (dereference-pointer-to-byte x)
                             (dereference-pointer-to-byte y))))
               (ptr     (procedure->pointer int cmp (list '* '*)))
               (bv      (u8-list->bytevector input)))
          (qsort (bytevector->pointer bv) (bytevector-length bv) 1 ptr)
          (and called?
               (equal? (bytevector->u8-list bv)
                       (sort input <))))
        (throw 'unresolved)))

  ;; A callback declared to return `int' that actually returns #f must
  ;; lead to a wrong-type error.
  (pass-if-exception "qsort, wrong return type"
    exception:wrong-type-arg

    (if (and qsort (defined? 'procedure->pointer))
        (let* ((cmp     (lambda (x y) #f)) ; wrong return type
               (ptr     (procedure->pointer int cmp (list '* '*)))
               (bv      (u8-list->bytevector input)))
          (qsort (bytevector->pointer bv) (bytevector-length bv) 1 ptr)
          #f)
        (throw 'unresolved)))

  ;; A three-argument procedure used with a two-argument callback
  ;; signature must lead to a wrong-number-of-arguments error.
  (pass-if-exception "qsort, wrong arity"
    exception:wrong-num-args

    (if (and qsort (defined? 'procedure->pointer))
        (let* ((cmp     (lambda (x y z) #f)) ; wrong arity
               (ptr     (procedure->pointer int cmp (list '* '*)))
               (bv      (u8-list->bytevector input)))
          (qsort (bytevector->pointer bv) (bytevector-length bv) 1 ptr)
          #f)
        (throw 'unresolved)))

  ;; Turning a Scheme procedure into a C function pointer and back must
  ;; give a procedure that computes the same results as the original.
  (pass-if "bijection"
    (if (defined? 'procedure->pointer)
        (let* ((proc  (lambda (x y z)
                        (+ x y z 0.0)))
               (ret   double)
               (args  (list float int16 double))
               (proc* (pointer->procedure ret
                                          (procedure->pointer ret proc args)
                                          args))
               (arg1  (map (cut / <> 2.0) (iota 123)))
               (arg2  (iota 123 32000))
               (arg3  (map (cut / <> 4.0) (iota 123 100 4))))
          (equal? (map proc arg1 arg2 arg3)
                  (map proc* arg1 arg2 arg3)))
        (throw 'unresolved)))

  (pass-if "procedures returning a pointer"
    (if (defined? 'procedure->pointer)
        (let* ((called? #f)
               (proc    (lambda (i) (set! called? #t) (make-pointer i)))
               (pointer (procedure->pointer '* proc (list int)))
               (proc*   (pointer->procedure '* pointer (list int)))
               (result  (proc* 777)))
          (and called? (equal? result (make-pointer 777))))
        (throw 'unresolved)))

  (pass-if "procedures returning void"
    (if (defined? 'procedure->pointer)
        (let* ((called? #f)
               (proc    (lambda () (set! called? #t)))
               (pointer (procedure->pointer void proc '()))
               (proc*   (pointer->procedure void pointer '())))
          (proc*)
          called?)
        (throw 'unresolved)))

  (pass-if "procedure is retained"
    ;; The lambda passed to `procedure->pointer' must remain live.
    ;; Allocate many foreign procedures from the one callback pointer,
    ;; force several collections, then check every one still works.
    (if (defined? 'procedure->pointer)
        (let* ((ptr   (procedure->pointer int
                                          (lambda (x) (+ x 7))
                                          (list int)))
               (procs (unfold (cut >= <> 10000)
                              (lambda (i)
                                (pointer->procedure int ptr (list int)))
                              1+
                              0)))
          (gc) (gc) (gc)
          (every (cut = <> 9)
                 (map (lambda (f) (f 2)) procs)))
        (throw 'unresolved)))

  ;; `qsort' was declared with four argument types, so the foreign
  ;; procedure must report four required arguments, no optional ones and
  ;; no rest argument.
  (pass-if "arity"
    (if (and qsort (defined? 'procedure->pointer))
        (equal? '(4 0 #f) (procedure-minimum-arity qsort))
        (throw 'unresolved))))


(with-test-prefix "structs"

  (define (round-trip layout data)
    ;; Pack DATA into a C struct described by LAYOUT, then unpack it
    ;; again with the same layout.
    (parse-c-struct (make-c-struct layout data) layout))

  ;; Size and alignment computations.

  (pass-if "sizeof { int8, double }"
    (let ((expected (+ (alignof double) (sizeof double))))
      (= (sizeof (list int8 double)) expected)))

  (pass-if "sizeof { double, int8 }"
    (let ((expected (+ (alignof double) (sizeof double))))
      (= (sizeof (list double int8)) expected)))

  ;; The struct is at least as large as the sum of its members.
  (pass-if "sizeof { short, int, long, pointer }"
    (let* ((layout       (list short int long '*))
           (member-total (apply + (map sizeof layout))))
      (>= (sizeof layout) member-total)))

  (pass-if "alignof { int8, double, int8 }"
    ;; alignment of the most strictly aligned component
    (= (alignof (list int8 double int8))
       (alignof double)))

  ;; Packing values into a struct and parsing them back must return
  ;; `equal?' values.

  (pass-if "parse-c-struct"
    (let ((layout (list int64 uint8))
          (data   (list -300 43)))
      (equal? (round-trip layout data) data)))

  (pass-if "alignment constraints honored"
    (let ((layout (list int8 double))
          (data   (list -7 3.14)))
      (equal? (round-trip layout data) data)))

  (pass-if "int8, pointer"
    (let ((layout (list uint8 '*))
          (data   (list 222 (make-pointer 7777))))
      (equal? (round-trip layout data) data)))

  (pass-if "unsigned-long, int8, size_t"
    (let ((layout (list unsigned-long int8 size_t))
          (data   (list (expt 2 17) -128 (expt 2 18))))
      (equal? (round-trip layout data) data)))

  (pass-if "long, int, pointer"
    (let ((layout (list long int '*))
          (data   (list (- (expt 2 17)) -222 (make-pointer 777))))
      (equal? (round-trip layout data) data)))

  (pass-if "int8, pointer, short, double"
    (let ((layout (list int8 '* short double))
          (data   (list 77 %null-pointer -42 3.14)))
      (equal? (round-trip layout data) data)))

  ;; Nested struct member.
  (pass-if "int8, { int8, double, int8 }, int16"
    (let ((layout (list int8 (list int8 double int8) int16))
          (data   (list 77 (list 42 4.2 55) 88)))
      (equal? (round-trip layout data) data))))


(with-test-prefix "lib->cyg"

  (define (check name input expected)
    ;; Run one test called NAME asserting that `lib->cyg' maps the string
    ;; INPUT to the string EXPECTED.
    (pass-if name
      (string=? (lib->cyg input) expected)))

  ;; A missing library name is passed through as #f.
  (pass-if "name is #f"
    (equal? #f (lib->cyg #f)))

  ;; Bare file names: only a leading "lib" is expected to become "cyg".
  (check "name too short"             "a"          "a")
  (check "name starts with 'lib'"     "libfoo.dll" "cygfoo.dll")
  (check "name contains 'lib'"        "foolib.dll" "foolib.dll")
  (check "name doesn't contain 'lib'" "foobar.dll" "foobar.dll")

  ;; File names with a directory part: the directory, even one called
  ;; "lib", is expected to stay as it is.
  (check "name in path too short"
         "/lib/a" "/lib/a")
  (check "name in path starts with 'lib'"
         "/lib/libfoo.dll" "/lib/cygfoo.dll")
  (check "name in path contains 'lib'"
         "/lib/foolib.dll" "/lib/foolib.dll")
  (check "name in path doesn't contain 'lib'"
         "/lib/foobar.dll" "/lib/foobar.dll")

  ;; Backslash-separated Windows paths.
  (check "name in windows path starts with 'lib'"
         "c:\\lib\\libfoo.dll" "c:\\lib\\cygfoo.dll")
  (check "name in windows path doesn't contain 'lib'"
         "c:\\lib\\foobar.dll" "c:\\lib\\foobar.dll"))
